import torch

### set seeds
# Seed torch, numpy and Python's `random` before the project modules
# (src.*) are imported further down, intended to make model initialisation
# and data shuffling reproducible across runs.
# NOTE(review): no cuDNN determinism flags are set in this section, so GPU
# runs may still differ run to run — confirm whether that matters here.
torch.manual_seed(0)
import numpy as np

np.random.seed(0)
import random

random.seed(0)

### clean warnings
# NOTE(review): this silences *every* warning for the whole process
# (deprecations, numerical/precision warnings from torch, ...), which can
# hide real problems during training.
import warnings

warnings.filterwarnings("ignore")

import argparse

from src.optimizer_DLRT.dlrt_optimizer import DLRT_Optimizer
from src.training import trainer
from src.training import profiler_trainer

from src.datasets.dataset_utils import choose_dataset
from src.models.model_utils import choose_model


# Accepted spellings for boolean command-line values (compared lower-cased).
_TRUE_STRINGS = frozenset({'true', 't', 'yes', 'y', '1', 'on'})
_FALSE_STRINGS = frozenset({'false', 'f', 'no', 'n', '0', 'off'})


def str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type``.

    ``type=bool`` is a trap: argparse calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy. This converter interprets the text.

    Args:
        value: Raw argument string. A ``bool`` is returned unchanged.

    Returns:
        ``True`` for true/t/yes/y/1/on and ``False`` for false/f/no/n/0/off
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Create the command-line parser for TDLRT training.

    Boolean options go through ``str2bool`` so ``--flag False`` really means False.
    Help texts use ``%(default)s`` so the printed defaults cannot drift from the code.

    Returns:
        argparse.ArgumentParser: The configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser(description='Pytorch TDLRT training')
    # Arguments for network training
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=60, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--epochs_ft', type=int, default=30, metavar='N',
                        help='number of fine-tuning epochs run after training, 0 disables fine tuning '
                             '(default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                        help='learning rate for dlrt optimizer (default: %(default)s)')
    parser.add_argument('--tau', type=float, default=0.2, metavar='tau',
                        help='cutting rank parameter tau, forwarded to the model (default: %(default)s)')
    parser.add_argument('--wd', type=float, default=0.001, metavar='wd',
                        help='weight decay on S and weights (default: %(default)s)')
    parser.add_argument('--momentum', type=float, default=0.1, metavar='MOMENTUM',
                        help='momentum (default: %(default)s)')
    parser.add_argument('--workers', type=int, default=1, metavar='WORKERS',
                        help='number of workers for the dataloaders (default: %(default)s)')
    # Arguments for network save n load
    parser.add_argument('--save_weights', type=str2bool, default=False, metavar='SAVE_WEIGHTS',
                        help='save the weights of the best validation model during the run (default: %(default)s)')
    parser.add_argument('--save_progress', type=str2bool, default=False,
                        help='save progress csv (TEST) (default: %(default)s)')
    parser.add_argument('--load_weights', type=str2bool, default=False,
                        help='load standard weights for the model (TEST) (default: %(default)s)')
    parser.add_argument('--load_model_path', type=str, default=None, metavar='LOAD_MODEL_PATH',
                        help='Loads the model given the full path including the filename. Basepath is where main.py is '
                             'located. '
                             'The user needs to take care to load the correct model for the dataset '
                             '(default: %(default)s)')

    parser.add_argument("--net_name", default='lenet5',
                        choices=["lenet5", "vgg16", 'alexnet'],
                        help='network architecture (default: %(default)s)')
    parser.add_argument("--dataset_name", default='mnist',
                        choices=["mnist", "cifar10", "fashion_mnist"],
                        help='dataset to train on (default: %(default)s)')
    parser.add_argument('--cv_run', type=int, default=0,
                        help='number of cross validation run to add to savename (default: %(default)s)')

    # Arguments for Low-Rank Discretization
    parser.add_argument('--chain_init', type=str2bool, default=False,
                        help='add chain initialization (TEST) (default: %(default)s)')
    parser.add_argument('--tucker', type=str2bool, default=False,
                        help='add tucker convolution (TEST) (default: %(default)s)')
    parser.add_argument('--adaptive', type=str2bool, default=False,
                        help='adaptive flag forwarded to the model builder (TEST) (default: %(default)s)')
    parser.add_argument('--mat_dlrt', type=str2bool, default=False,
                        help='add matrix linear dlrt layers (TEST) (default: %(default)s)')
    parser.add_argument('--baseline', type=str2bool, default=False,
                        help='baseline flag forwarded to the model builder and the dlrt optimizer (TEST) '
                             '(default: %(default)s)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device (cuda or cpu), falls back to cpu when CUDA is unavailable (default: %(default)s)')
    # Misc Arguments
    parser.add_argument('--profiler', type=str2bool, default=False,
                        help='toggle Timing profiler (default: %(default)s)')
    parser.add_argument('--datapath', type=str, default="../../../../data02/",
                        help='path to folder (default: %(default)s)')
    return parser


def accuracy(outputs, labels):
    """Count the correctly classified samples in a batch.

    Args:
        outputs: Model scores. Predictions are the argmax over dim 1
            (assumes ``(batch, num_classes)`` logits — TODO confirm).
        labels: Class labels compared element-wise with those predictions.

    Returns:
        torch.Tensor: 0-dim float32 tensor holding the *number* of correct
        predictions (not a ratio), on the same device as the inputs.
    """
    predictions = torch.argmax(outputs.detach(), dim=1)
    # Count in float32. float16 cannot represent every integer above 2048,
    # so large batches would be silently rounded.
    return (predictions == labels).float().sum()


def main():
    """Parse CLI arguments, build model/data/optimizer, then train and optionally fine-tune.

    Training runs for ``--epochs`` epochs through ``profiler_trainer`` (when
    ``--profiler`` is set) or ``trainer``. If ``--epochs_ft`` is positive, a second
    ``fine_tune=True`` pass follows with the same model, optimizer and scheduler.
    """
    args = build_parser().parse_args()

    # setup cuda: fall back to cpu when CUDA is unavailable, whatever was requested
    device = args.device if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    args.device = device

    criterion = torch.nn.CrossEntropyLoss()

    # -------- Network Selection -----------
    f = choose_model(model_name=args.net_name, baseline=args.baseline, tucker=args.tucker, mat_dlrt=args.mat_dlrt,
                     adaptive=args.adaptive, tau=args.tau, device=args.device, chain_init=args.chain_init,
                     dataset_name=args.dataset_name, load_model_path=args.load_model_path,
                     load_weights=args.load_weights)

    # -------- Dataset Selection -----------
    train_loader, val_loader, test_loader = choose_dataset(dataset_name=args.dataset_name, batch_size=args.batch_size,
                                                           num_workers=args.workers, datapath=args.datapath)

    # -------- Optimizer Selection ---------
    optimizer = DLRT_Optimizer(f, baseline=args.baseline, lr=args.lr, momentum=args.momentum, wd=args.wd)

    # -------- LR Scheduler Selection ---------
    # The scheduler is attached to `optimizer.integrator` (presumably the
    # underlying torch optimizer — confirm in DLRT_Optimizer).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer.integrator, factor=0.1, patience=5)

    # -------- Trainer Selection ---------
    path = '../results/'
    save_name = f'{args.net_name}_{args.dataset_name}_tau{args.tau}_baseline{args.baseline}_mom{args.momentum}_lr{args.lr}_cv{args.cv_run}'

    # Arguments shared by the training pass and the fine-tuning pass.
    common_kwargs = dict(optimizer=optimizer, criterion=criterion, train_loader=train_loader,
                         validation_loader=val_loader, test_loader=test_loader, metric=accuracy,
                         device=args.device, path=path, save_weights=args.save_weights,
                         save_progress=args.save_progress, scheduler=scheduler, save_name=save_name)

    if args.profiler:
        train_fn = profiler_trainer.train
    else:
        train_fn = trainer.train
        # Only the plain trainer is given the per-epoch status-bar flag.
        common_kwargs['epoch_status_bar'] = True
        print(
            f'TRAINING {args.net_name} ON {args.dataset_name} with parameters lr{args.lr},tau {args.tau},baseline {args.baseline}')

    train_fn(f, epochs=args.epochs, **common_kwargs)

    if args.epochs_ft > 0:
        print('START FINE TUNING')
        print('=' * 40)
        train_fn(f, epochs=args.epochs_ft, fine_tune=True, **common_kwargs)


if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
